// DllNanny.cpp

/*
Configure the following registry entries:

[HKEY_CLASSES_ROOT\CLSID\{10000002-0000-0000-0000-000000000001}]
"AppID"="{10000002-0000-0000-0000-000000000001}"

[HKEY_CLASSES_ROOT\AppID\{10000002-0000-0000-0000-000000000001}]
@="Inside DCOM Sample"
"DllSurrogate"="C:\\MyDirectory\\DllNanny.exe"

  or, to use the system-supplied default surrogate (dllhost.exe) instead, set:

"DllSurrogate"=""


  The surrogate.reg file has a template for these registry entries

*/

#define _WIN32_DCOM
#include <windows.h>
#include <iostream.h>  // For cout
#include <conio.h>     // For _getch

// Event that main() blocks on after loading the DLL server;
// CSurrogate::FreeSurrogate signals it so the surrogate process can exit.
HANDLE g_hEvent = NULL;

class CGenericFactory;
class CSurrogate : public ISurrogate
{
public:
	// IUnknown
	ULONG __stdcall AddRef();
	ULONG __stdcall Release();
	HRESULT __stdcall QueryInterface(REFIID riid, void** ppv);

	// ISurrogate
	HRESULT __stdcall LoadDllServer(REFCLSID rclsid);
	HRESULT __stdcall FreeSurrogate();

private:
	CGenericFactory* m_pcf;
};

// Generic class factory registered on behalf of a DLL server. CreateInstance
// forwards every activation request to the in-process class m_clsid.
class CGenericFactory : public IClassFactory
{
public:
	// IUnknown
	ULONG __stdcall AddRef();
	ULONG __stdcall Release();
	HRESULT __stdcall QueryInterface(REFIID iid, void** ppv);

	// IClassFactory
	HRESULT __stdcall CreateInstance(IUnknown *pUnknownOuter, REFIID iid, void** ppv);
	HRESULT __stdcall LockServer(BOOL bLock);

	// The reference count starts at zero; the first AddRef comes from
	// whoever registers or queries the factory. explicit: a CLSID must not
	// silently convert into a factory object.
	explicit CGenericFactory(REFCLSID rclsid) : m_cRef(0), m_dwRegister(0), m_clsid(rclsid) { }
	~CGenericFactory() { }

private:
	ULONG m_cRef;        // reference count; Release deletes the object at zero
	DWORD m_dwRegister;  // cookie from CoRegisterClassObject, used to revoke
	friend class CSurrogate;
	CLSID m_clsid;       // CLSID of the DLL server class being hosted
};

ULONG CSurrogate::AddRef()
{
	// The surrogate is a stack object in main() that is expected to outlive
	// every COM reference to it, so no count is kept; a fixed value is reported.
	const ULONG cFixedRef = 2;
	return cFixedRef;
}

ULONG CSurrogate::Release()
{
	// Counterpart of AddRef: the object is not heap-allocated, so nothing is
	// ever freed here; a fixed non-zero value is reported.
	const ULONG cFixedRef = 1;
	return cFixedRef;
}

// Hands out IUnknown or ISurrogate for this object.
// Returns S_OK, E_NOINTERFACE (with *ppv set to NULL), or E_POINTER if
// ppv itself is NULL.
HRESULT CSurrogate::QueryInterface(REFIID riid, void** ppv)
{
	if(ppv == NULL)
		return E_POINTER;

	// Both casts are plain upcasts through the single-inheritance chain.
	if(riid == IID_IUnknown)
		*ppv = static_cast<IUnknown*>(this);
	else if(riid == IID_ISurrogate)
		*ppv = static_cast<ISurrogate*>(this);
	else
	{
		*ppv = NULL;
		return E_NOINTERFACE;
	}
	AddRef();
	return S_OK;
}

// Creates a generic factory for rclsid and registers it with COM using
// REGCLS_SURROGATE, so activation requests for that class reach this process.
// On success m_pcf points at the factory (COM's registration keeps it alive);
// on failure m_pcf is NULL and the factory has been freed.
// NOTE(review): only the most recently loaded factory is tracked, so if this
// is called for several CLSIDs, FreeSurrogate revokes only the last one.
HRESULT CSurrogate::LoadDllServer(REFCLSID rclsid)
{
	cout << "Surrogate: ISurrogate::LoadDllServer" << endl;

	CGenericFactory* pcf = new CGenericFactory(rclsid);
	if(pcf == NULL)  // pre-standard compilers return NULL instead of throwing
		return E_OUTOFMEMORY;

	// The factory starts with a reference count of zero. Hold a temporary
	// reference across registration so that a transient AddRef/Release inside
	// COM cannot delete it. If registration fails, dropping this reference
	// frees it instead of leaking it.
	pcf->AddRef();
	HRESULT hr = CoRegisterClassObject(rclsid, static_cast<IClassFactory*>(pcf),
		CLSCTX_LOCAL_SERVER, REGCLS_SURROGATE, &pcf->m_dwRegister);
	pcf->Release();  // on success COM's own reference keeps the factory alive

	m_pcf = SUCCEEDED(hr) ? pcf : NULL;
	return hr;
}

// Revokes the registered class factory (if any) and signals g_hEvent so that
// main() stops waiting and the surrogate process can exit.
// Returns the result of CoRevokeClassObject, or S_OK if nothing was registered.
HRESULT CSurrogate::FreeSurrogate()
{
	cout << "Surrogate: ISurrogate::FreeSurrogate" << endl;

	HRESULT hr = S_OK;
	if(m_pcf != NULL)
	{
		// Revoking drops COM's reference, which may be the last one; the
		// factory may delete itself here, so m_pcf must not be used afterwards.
		hr = CoRevokeClassObject(m_pcf->m_dwRegister);
		m_pcf = NULL;
	}

	SetEvent(g_hEvent);
	return hr;
}

// Increments the reference count and returns the new value.
// The factory is registered from a multithreaded apartment, so COM can call
// AddRef/Release concurrently on different threads; the update must be atomic.
ULONG CGenericFactory::AddRef()
{
	ULONG cRef = InterlockedIncrement(reinterpret_cast<LONG*>(&m_cRef));
	cout << "Surrogate: CGenericFactory::AddRef() m_cRef = " << cRef << endl;
	return cRef;
}

// Decrements the reference count atomically and deletes the factory when it
// reaches zero. Returns the new count. It uses the value returned by
// InterlockedDecrement rather than re-reading m_cRef, which another thread
// may have changed (or which may already be freed).
ULONG CGenericFactory::Release()
{
	ULONG cRef = InterlockedDecrement(reinterpret_cast<LONG*>(&m_cRef));
	cout << "Surrogate: CGenericFactory::Release() m_cRef = " << cRef << endl;
	if(cRef == 0)
		delete this;
	return cRef;
}

// Hands out IUnknown or IClassFactory (both map to the same pointer).
// Returns S_OK, E_NOINTERFACE (with *ppv set to NULL), or E_POINTER if
// ppv itself is NULL.
HRESULT CGenericFactory::QueryInterface(REFIID riid, void** ppv)
{
	if(ppv == NULL)
		return E_POINTER;

	if(riid == IID_IUnknown || riid == IID_IClassFactory)
		*ppv = static_cast<IClassFactory*>(this);
	else
	{
		*ppv = NULL;
		return E_NOINTERFACE;
	}
	AddRef();
	return S_OK;
}

// Forwards the activation request to COM, asking for the hosted class as an
// in-process server so that its DLL is loaded into this surrogate process.
// All arguments and the resulting HRESULT are passed through unchanged.
HRESULT CGenericFactory::CreateInstance(IUnknown *pUnknownOuter, REFIID riid, void** ppv)
{
	const DWORD dwClsCtx = CLSCTX_INPROC_SERVER;
	HRESULT hrCreate = CoCreateInstance(m_clsid, pUnknownOuter, dwClsCtx, riid, ppv);
	return hrCreate;
}

// Lock requests are accepted but not tracked. The process lifetime is
// controlled by FreeSurrogate signalling g_hEvent instead.
HRESULT CGenericFactory::LockServer(BOOL /* bLock */)
{
	const HRESULT hrAccepted = S_OK;
	return hrAccepted;
}

void main(int argc, char** argv)
{
	if(argc < 2)
	{
		cout << "DllNanny must be properly registered as a DLL Surrogate" << endl;
		return;
	}

	cout << "Surrogate: CoInitializeEx()" << endl;
	CoInitializeEx(NULL, COINIT_MULTITHREADED);

	CSurrogate surrogate;
	CoRegisterSurrogate(&surrogate);

	cout << "Surrogate: The DLL's CLSID is " << argv[1] << endl;

	OLECHAR wszCLSID[39];
	mbstowcs(wszCLSID, argv[1], 39);

	CLSID clsid;
	CLSIDFromString(wszCLSID, &clsid);

	surrogate.LoadDllServer(clsid);

	g_hEvent = CreateEvent(NULL, FALSE, FALSE, NULL);
	WaitForSingleObject(g_hEvent, INFINITE);
	CloseHandle(g_hEvent);

	CoUninitialize();

	_getch();
}